
Conversation

@QuentinJGMace (Collaborator) commented Sep 29, 2025

  • adds vbert
  • fixes training with multiple hardnegs (see the sketch after this list)

Still to do:
Modify the loss for all negatives, not just the ones used for vbert
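For context, a minimal, illustrative sketch of what a bi-encoder loss with several hard negatives per query can look like. This is not the engine's actual loss; the function name, tensor shapes, and temperature are assumptions for illustration only.

```python
import torch
import torch.nn.functional as F

def bi_loss_with_hardnegs(query_emb, pos_doc_emb, neg_doc_emb, temperature=0.02):
    """
    query_emb:   (batch, dim)           pooled query embeddings
    pos_doc_emb: (batch, dim)           pooled positive-document embeddings
    neg_doc_emb: (batch, num_neg, dim)  hard negatives for each query
    """
    # In-batch scores: every query against every positive document in the batch.
    in_batch_scores = query_emb @ pos_doc_emb.T  # (batch, batch)
    # Hard-negative scores: each query against its own hard negatives only.
    hardneg_scores = torch.einsum("bd,bnd->bn", query_emb, neg_doc_emb)  # (batch, num_neg)
    # The positive for query i is column i; the other in-batch documents and the
    # appended hard negatives all act as negatives.
    logits = torch.cat([in_batch_scores, hardneg_scores], dim=1) / temperature
    targets = torch.arange(query_emb.shape[0], device=query_emb.device)
    return F.cross_entropy(logits, targets)
```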

QuentinJGMace and others added 13 commits September 30, 2025 09:53
* modeling

* update modeling

* update token id default

* init files

* remove vllama + update torch lower bound for cpu

* back to normal transformer bound

* clean

* Update colpali_engine/models/__init__.py

---------

Co-authored-by: QuentinJGMace <[email protected]>
@mlconti1 (Collaborator) left a comment


Mostly comments about the form, overall LGTM!

Comment on lines -15 to +18
ColQwen2_5Omni,
ColQwen2_5OmniProcessor,
# ColQwen2_5Omni,
# ColQwen2_5OmniProcessor,

Add a comment to the README if ColQwen 2.5 Omni is not supported anymore


# Process queries.
queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries]
# queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries]

remove commented lines if not useful

@QuentinJGMace (Author) replied:

Actually useful: in ModernVBERT, self.processor.query_prefix is "" but it is useful if somebody wants to reproduce older models.
Thanks for flagging it!
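To illustrate that point, a small sketch of the query construction, assuming only the processor attributes visible in the snippet (the helper name and the n_aug default are made up for the example):

```python
def build_queries(queries, query_prefix, query_augmentation_token, n_aug=10):
    # With query_prefix == "" (as in ModernVBERT) this reduces to the query plus
    # augmentation tokens; with a non-empty prefix it reproduces the formatting
    # used by older models.
    return [query_prefix + q + query_augmentation_token * n_aug for q in queries]

build_queries(["what is colpali?"], query_prefix="", query_augmentation_token="<pad>")
# -> ["what is colpali?" followed by 10 "<pad>" tokens]
```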

# Process queries.
queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries]
# queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries]
queries = [q + self.processor.query_augmentation_token * 10 for q in queries] if is_str else queries

put 10 into a constant (e.g. N_AUGMENTATION_TOKENS)
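For instance, a sketch of the suggested change (the constant name follows the suggestion above; everything else comes from the snippet):

```python
N_AUGMENTATION_TOKENS = 10  # number of augmentation tokens appended to each query

queries = (
    [q + self.processor.query_augmentation_token * N_AUGMENTATION_TOKENS for q in queries]
    if is_str
    else queries
)
```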

Comment on lines +126 to +127
else:
proc_batch[k] = v

unnecessary

query_outputs = model(input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"])
query_outputs = model(**{k[6:]: v for k, v in inputs.items() if k.startswith("query")})
# feed only kwargs with 'doc_' prefix
doc_outputs = model(**{k[4:]: v for k, v in inputs.items() if k.startswith("doc")})

define var/constant for len("doc:")
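A possible shape for this, with the prefix strings pulled into constants so the slice length is derived rather than hard-coded (the constant names are illustrative; the filtering logic is taken from the snippet above):

```python
QUERY_PREFIX = "query_"
DOC_PREFIX = "doc_"

# Strip the prefix by its length instead of hard-coding 6 and 4.
query_outputs = model(
    **{k[len(QUERY_PREFIX):]: v for k, v in inputs.items() if k.startswith(QUERY_PREFIX)}
)
doc_outputs = model(
    **{k[len(DOC_PREFIX):]: v for k, v in inputs.items() if k.startswith(DOC_PREFIX)}
)
```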

"""
Helper function to reshape negative doc inputs to (batch_size * num_neg_docs, ...)
"""
neg_doc_inputs = {k[8:]: v for k, v in inputs.items() if k.startswith("neg_doc")}

define var/constant for 8
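Same idea for the negative-doc prefix. The reshape body below is an assumption about how the (batch_size * num_neg_docs, ...) flattening described in the docstring might be written; only the prefix filtering comes from the snippet:

```python
NEG_DOC_PREFIX = "neg_doc_"

def reshape_neg_doc_inputs(inputs):
    """Flatten negative-doc tensors from (batch_size, num_neg_docs, ...) to (batch_size * num_neg_docs, ...)."""
    neg_doc_inputs = {
        k[len(NEG_DOC_PREFIX):]: v for k, v in inputs.items() if k.startswith(NEG_DOC_PREFIX)
    }
    return {k: v.flatten(0, 1) for k, v in neg_doc_inputs.items()}
```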


could rename variables for more clarity and use constants, and add doc


save as test_bi_losses

assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}"

# # Check if the maximum scores per row are in the diagonal of the matrix score
# assert (scores.argmax(dim=1) == torch.arange(len(ds), device=scores.device)).all()

why is this commented out?

@athrael-soju's comment was marked as resolved.
